from tqdm import tqdm
from utils.utils import break_down_into_subquestions, call_model, remove_extra_target_occurrences, mean_pooling, \
    retrieve_facts, call_model_template
import torch
import json
from mquake_dataset import MQUAKE

def retrieve_facts2(query, fact_embs, contriever, tok, k=1, threshold=0.845):
    """Return the indices of the ``k`` best-scoring facts for ``query``.

    The query is tokenized with ``tok``, moved to the CUDA device, encoded by
    ``contriever`` and reduced with ``mean_pooling``; the resulting embedding
    is moved back to the CPU and scored against ``fact_embs`` with a matrix
    product.

    Args:
        query: Text to look up.
        fact_embs: Candidate embeddings; transposed before the product, so
            presumably one row per fact -- TODO confirm.
        contriever: Encoder called with the tokenizer output as keyword
            arguments.
        tok: Tokenizer producing PyTorch tensors.
        k: Number of top candidates requested from ``topk``.
        threshold: Minimum score the best candidate must reach.

    Returns:
        The ``topk`` indices, or ``None`` when the highest score is below
        ``threshold``.
    """
    encoded = tok([query], padding=True, truncation=True, return_tensors='pt').to('cuda')
    with torch.no_grad():
        encoder_out = contriever(**encoded)
        query_emb = mean_pooling(encoder_out[0], encoded['attention_mask']).cpu()

    # Only one query was encoded, so row 0 holds its score against every fact.
    scores = (query_emb @ fact_embs.T)[0]
    best = scores.topk(k, largest=True)

    # topk returns values in descending order: values[0] is the best score.
    if best.values[0] < threshold:
        return None
    return best.indices


def get_relation2(subquestion, rels, rel_emb, contriever, tokenizer, retriever_threshold):
    """Pick the relation from ``rels`` that best matches ``subquestion``.

    Delegates the similarity search to ``retrieve_facts2`` (with its default
    ``k``) using ``retriever_threshold`` as the score cut-off.

    Returns:
        The entry of ``rels`` at the top retrieved index, or ``None`` when the
        retriever reports no candidate above the threshold.
    """
    hits = retrieve_facts2(subquestion, rel_emb, contriever, tokenizer, threshold=retriever_threshold)
    return None if hits is None else rels[hits[0]]


def fetch_rel_subj2subq(subject, rel, relation2subq_prompt, sc_end_block, model, model_tokenizer, device, use_template=False):
    """Ask the language model to phrase ``rel`` applied to ``subject`` as a question.

    The request is appended to ``relation2subq_prompt`` and sent through
    ``call_model_template`` when ``use_template`` is true, otherwise through
    ``call_model``. The question is then cut out of the returned text.

    Returns:
        The text between the first pair of double quotes on the second line
        of the sixth blank-line-separated block of the model output.

    Raises:
        IndexError: If the model output does not have that layout.
    """
    prompt = relation2subq_prompt + f"Given this relation: \"{rel}\" and this subject: \"{subject}\",\nThe corresponding question is"

    generate = call_model_template if use_template else call_model
    raw = generate(prompt, sc_end_block, temperature=0.2, generate_length=20, model=model,
                   gptj_tokenizer=model_tokenizer, device=device)

    # NOTE(review): index 5 presumably skips five few-shot blocks contained in
    # relation2subq_prompt -- confirm against the prompt file.
    answer_block = raw.strip().split('\n\n')[5]
    question_line = answer_block.strip().split('\n')[1]
    return question_line.strip().split('\"')[1]


def get_correct_track(d, edited):
    """Return the first two fields of every reference triple of case ``d``.

    Args:
        d: Case dict; ``d['orig']`` must hold the ``'triples'`` and/or
            ``'new_triples'`` sequences.
        edited: When true the ``'new_triples'`` list is read, otherwise
            ``'triples'``.

    Returns:
        A list of ``(triple[0], triple[1])`` tuples, in the stored order.
    """
    if edited:
        source = d['orig']['new_triples']
    else:
        source = d['orig']['triples']

    pairs = []
    for triple in source:
        pairs.append((triple[0], triple[1]))
    return pairs


def get_fact_form_kg(subject, rel, entity2id, ent2alias, rel2id, kg_s_r_o, id2entity, caseid, track, masking, logger):
    """Look up the fact stored in the knowledge graph for ``(subject, rel)``.

    Args:
        subject: Entity name, or an alias listed in ``ent2alias``.
        rel: Relation template containing a ``{}`` slot for the subject, or
            ``None`` when no relation was retrieved.
        entity2id: Entity name -> id.
        ent2alias: Entity name -> collection of aliases.
        rel2id: Relation template -> id.
        kg_s_r_o: ``subject_id -> rel_id -> (object_id, caseids)``.
        id2entity: Entity id -> name.
        caseid: Id of the case currently being evaluated.
        track: ``(subject, relation)`` pairs of the current case, compared
            against ``(subject_id, rel_id)``.
        masking: Enables the dynamic masking described below.
        logger: Unused.

    Returns:
        ``(fact, True, object_name)`` when a usable fact exists, otherwise
        ``("<no fact>", False, None)``.
    """
    no_fact = ("<no fact>", False, None)

    # Resolve the subject directly, or through the first entity listing it as an alias.
    if subject in entity2id:
        subject_id = entity2id[subject]
    else:
        subject_id = next(
            (entity2id[name] for name, aliases in ent2alias.items() if subject in aliases),
            None,
        )
    if subject_id is None:
        return no_fact

    # rel is retrieved by embedding similarity over all relations in the
    # dataset; None means nothing passed the threshold.
    if rel is None:
        return no_fact

    rel_id = rel2id[rel]
    if subject_id not in kg_s_r_o or rel_id not in kg_s_r_o[subject_id]:
        return no_fact

    object_id, caseids = list(kg_s_r_o[subject_id][rel_id])

    # Dynamic masking: hide an edit that does not belong to this case when it
    # would change the end object of an s-r-o link on this case's track. Only
    # s and r need to match because an edit always modifies the object.
    if masking and caseid not in caseids and (subject_id, rel_id) in track:
        return no_fact

    object_name = id2entity[object_id]
    return f'{rel.format(subject)} {object_name}', True, object_name


def fit_subject_on_kg(subject, entity2id, ent_emb, contriever, tokenizer, ents, kg_s_r_o, ent2alias, threshold=1000):
    """Map ``subject`` onto a known entity name where possible.

    A subject already present in ``entity2id`` is returned unchanged, as is
    any subject when ``ents`` is empty. Otherwise up to ten candidates are
    fetched with ``retrieve_facts`` and the first candidate that is a key of
    ``kg_s_r_o``, or that lists ``subject`` among its aliases, is returned.

    Args:
        subject: Entity text to normalise.
        entity2id: Known entity names.
        ent_emb: Embeddings searched by ``retrieve_facts``.
        contriever, tokenizer: Forwarded to ``retrieve_facts``.
        ents: Candidate entities, indexed by the retrieved indices.
        kg_s_r_o: Knowledge graph mapping checked for candidate membership.
        ent2alias: Entity name -> collection of aliases.
        threshold: Unused.

    Returns:
        The matched candidate, or ``subject`` itself when nothing matches.
    """
    if subject in entity2id or len(ents) == 0:
        return subject

    n_candidates = min(10, len(ents))
    for idx in retrieve_facts(subject, ent_emb, contriever, tokenizer, k=n_candidates):
        candidate = ents[idx]
        # NOTE(review): get_fact_form_kg indexes kg_s_r_o by entity id while
        # `candidate` is taken from `ents`; if `ents` holds names this
        # membership test may never succeed -- confirm the intended key type.
        is_kg_subject = candidate in kg_s_r_o
        if is_kg_subject or (candidate in ent2alias and subject in ent2alias[candidate]):
            return candidate
    return subject


def gwalk_eval_loop(dataset, task_prompt, sc_facts, model, model_tokenizer, device, rels,
                     rel_emb, logger, rel2subq, retriever_threshold,
                     contriever, tokenizer, entity2id, ent2alias, rel2id, kg_s_r_o,
                     id2entity, ent_emb, ents, sc_done, sc_end_block, relation2subq_prompt,
                     rand_list, print_prompt, breakdown_prompt, masking, result_file_path, pre_token_length=4):
    """Evaluate every case of ``dataset`` hop by hop and record the answers.

    For each case the questions are tried in order. Each question is walked
    along the relation chain produced by ``break_down_into_subquestions``:
    a subquestion is built for the current subject, the model generates an
    answer, the knowledge graph is consulted through ``get_fact_form_kg``,
    and the next subject is either the KG object (when a fact was found) or
    the generated answer. A case counts as correct as soon as one of its
    questions passes ``dataset.check_answer``; remaining questions are then
    skipped.

    Args:
        dataset: Object providing ``get_dataset()``, ``check_answer()`` and
            ``verify_subquestion_path()``.
        task_prompt: Prefix placed before every question prompt.
        sc_facts: Second argument of the per-hop ``call_model`` call
            (presumably its stopping criterion -- TODO confirm).
        model, model_tokenizer, device: Forwarded to the generation helpers.
        rels, rel_emb: Candidate relations and their embeddings; relation
            retrieval is skipped when ``rels`` is empty.
        logger: Receives progress and prompt traces via ``info``.
        rel2subq: Relation -> subquestion template with a ``{}`` subject slot.
        retriever_threshold: Score cut-off for relation retrieval.
        contriever, tokenizer: Encoder and tokenizer used by the retrievers.
        entity2id, ent2alias, rel2id, kg_s_r_o, id2entity: Knowledge-graph
            tables forwarded to ``get_fact_form_kg`` / ``fit_subject_on_kg``.
        ent_emb, ents: Entity embeddings and entities for subject fitting.
        sc_done: Forwarded to ``break_down_into_subquestions``.
        sc_end_block, relation2subq_prompt: Forwarded to
            ``fetch_rel_subj2subq`` when no relation is retrieved.
        rand_list: Case ids treated as edited.
        print_prompt: When true, separator banners are logged per case,
            per question and at the end of each question.
        breakdown_prompt: Forwarded to ``break_down_into_subquestions``.
        masking: Forwarded to ``get_fact_form_kg``.
        result_file_path: JSON file rewritten every 10 cases and once more
            at the end.
        pre_token_length: Number of leading characters dropped from the
            trimmed model output after each hop.

    Returns:
        None. Results are written to ``result_file_path`` and the logger.
    """
    cor = 0    # cases with a correct final answer
    tot = 0    # cases seen
    h_cor = 0  # correct cases whose subquestion path was also verified
    raw_answer_dict = {}

    for d in tqdm(dataset.get_dataset()):
        if print_prompt:
            case_rule = "=" * 50
            logger.info(f"{case_rule} Caseid = {d['case_id']} {case_rule}")
        tot += 1
        edit_flag = d['case_id'] in rand_list
        raw_answer_dict[d['case_id']] = {'edited': edit_flag}

        start_subject, breakdown_rels_list = break_down_into_subquestions(d, breakdown_prompt, sc_done,
                                                                          model_tokenizer, model)
        llm_answers = []
        for q_id, q in enumerate(d["questions"]):
            if print_prompt:
                question_rule = "=" * 30
                logger.info(f"{question_rule} q_id = {q_id + 1} {question_rule}")
            subject = start_subject
            ans = None
            prompt = task_prompt + "\n\nQuestion: " + q + "\n"
            for relation in breakdown_rels_list[q_id]:
                if len(rels):
                    rel = get_relation2(relation, rels, rel_emb, contriever, tokenizer, retriever_threshold)
                else:
                    rel = None

                # Without a retrieved relation there is no template to fill,
                # so the model phrases the subquestion itself.
                if rel is None:
                    subquestion = fetch_rel_subj2subq(subject, relation, relation2subq_prompt,
                                                      sc_end_block,
                                                      model=model,
                                                      model_tokenizer=model_tokenizer, device=device)
                else:
                    subquestion = rel2subq[rel].format(subject)
                prompt = prompt + "Subquestion: " + subquestion + "\n"

                prompt = call_model(prompt, sc_facts, model=model, temperature=0.2, device=device,
                                    gptj_tokenizer=model_tokenizer)

                # Drop a trailing "Retrieved fact:" header; it is re-added
                # below together with the fact looked up in the KG.
                if prompt.strip().split('\n')[-1] == 'Retrieved fact:':
                    prompt = prompt[:-len('\nRetrieved fact:')]
                prompt = remove_extra_target_occurrences(prompt, "Question: ", 5)[pre_token_length:]

                temp_split = prompt.strip().split('\n')
                if len(temp_split) < 2:
                    break  # failed case

                generated_answer = temp_split[-1][len("Generated answer: "):]

                # Generated answer: "XX is {}. YY" -- everything after the
                # first ". " is taken as the answer object.
                ga_seg = generated_answer.strip().split('. ')
                if len(ga_seg) < 2:
                    break
                answer_object = ". ".join(ga_seg[1:])
                logger.info("answer_object: %s" % answer_object)

                fact_sent, contra_or_not, fact_object = get_fact_form_kg(subject, rel, entity2id, ent2alias,
                                                                         rel2id, kg_s_r_o, id2entity, d['case_id'],
                                                                         get_correct_track(d, edit_flag), masking,
                                                                         logger)

                # A fact found in the KG overrides the generated answer.
                if contra_or_not:
                    does_or_doesnot = "contradicts"
                    inter_answer = fact_object
                else:
                    does_or_doesnot = "does not contradict"
                    inter_answer = answer_object

                contra_prompt = "Retrieved fact {} to generated answer, so continue with this subject: {}.\n".format(
                    does_or_doesnot, inter_answer)

                # Reset the pointer for the next hop.
                subject = fit_subject_on_kg(inter_answer, entity2id, ent_emb, contriever, tokenizer, ents,
                                            kg_s_r_o, ent2alias)
                ans = subject
                prompt = prompt + '\nRetrieved fact: ' + fact_sent + '.\n' + contra_prompt

            if print_prompt:
                end_rule = "=" * 20
                logger.info(f"{end_rule} End {end_rule}")
            logger.info(prompt[len(task_prompt) + 2:])

            llm_answers.append(ans)

            # A case is correct as soon as one of its questions is answered
            # correctly, so the remaining questions are skipped.
            if dataset.check_answer(edit_flag, d, ans):
                cor += 1
                if dataset.verify_subquestion_path(prompt[len(task_prompt):], d, edit_flag):
                    h_cor += 1
                break

        raw_answer_dict[d['case_id']]['answers'] = llm_answers

        # Periodic checkpoint so a crash does not lose all results.
        if tot % 10 == 0:
            with open(result_file_path, 'w') as f:
                json.dump(raw_answer_dict, f)

        logger.info("%s (%s), %s" % (cor, h_cor, tot))

    # Guard against an empty dataset, which would otherwise divide by zero.
    accuracy = cor / tot if tot else 0.0
    logger.info(f'Multi-hop acc = {accuracy} ({cor} / {tot})')

    with open(result_file_path, 'w') as f:
        json.dump(raw_answer_dict, f)